# Prepares and organizes the data consumed by the neural network.

import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Minimal Dataset wrapper around an in-memory sequence of samples.

    Intentionally thin: each item is returned as-is, so any batching
    logic is left to the DataLoader's collate step.
    """

    def __init__(self, datas):
        super().__init__()
        self.datas = datas

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, item):
        # Return the raw sample; no transformation here.
        return self.datas[item]

def sentence2vec(sentence, token2idx):
    """Encode *sentence* as a 1-D LongTensor of vocabulary indices.

    Args:
        sentence: iterable of tokens (e.g. a string of characters).
        token2idx: mapping from token to integer index; a missing token
            raises KeyError, same as the original implementation.

    Returns:
        torch.LongTensor of shape (len(sentence),); empty tensor for
        an empty sentence.
    """
    # Build the index list in one pass: the original torch.cat-per-token
    # loop re-copied the whole tensor every iteration (O(n^2) total).
    return torch.LongTensor([token2idx[token] for token in sentence])

def sentence_filter(datas, long_limit):
    """Return only the entries whose 'sentence' is shorter than *long_limit*.

    Args:
        datas: iterable of dicts, each with a 'sentence' key.
        long_limit: exclusive upper bound on sentence length.

    Returns:
        New list containing the entries that passed the length check.
    """
    return [entry for entry in datas if len(entry['sentence']) < long_limit]

# This dataset layout needs no extra collate_fn.
class Seq2seqDataSet(Dataset):
    """Seq2seq dataset over consecutive sentences.

    Item i is the pair (sentence i, sentence i+1), both encoded to index
    tensors via `sentence2vec` and right-padded with index 0 up to
    `long_limit`, so every sample has a fixed shape and the default
    collate works.
    """

    def __init__(self, datas, token2idx, long_limit):
        super(Seq2seqDataSet, self).__init__()
        # Drop sentences of length >= long_limit so the padding in
        # __getitem__ is always sufficient (never needs truncation).
        self.datas = sentence_filter(datas, long_limit)
        self.token2idx = token2idx
        self.long_limit = long_limit

    def __len__(self):
        # Consecutive pairs -> one fewer item than sentences. Clamp at 0:
        # the original returned -1 for an empty filtered dataset, which
        # breaks DataLoader/Sampler length handling.
        return max(len(self.datas) - 1, 0)

    def __getitem__(self, item):
        q = sentence2vec(self.datas[item]['sentence'], self.token2idx)
        a = sentence2vec(self.datas[item + 1]['sentence'], self.token2idx)

        # Right-pad with 0 (assumed to be the padding index — TODO confirm
        # against the vocabulary) up to the fixed length.
        if len(q) < self.long_limit:
            q = torch.cat([q, torch.LongTensor([0] * (self.long_limit - len(q)))])

        if len(a) < self.long_limit:
            a = torch.cat([a, torch.LongTensor([0] * (self.long_limit - len(a)))])

        return q, a